{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a03f354c",
   "metadata": {},
   "source": [
    "# 1 概述"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "065c462f",
   "metadata": {},
   "source": [
    "训练过程的可视化在深度学习模型训练中扮演着重要的角色。谷歌为Tensorflow打造可视化工具Tensorboard，而Facebook为PyTorch开发了一款可视化工具，名为**Visdom**。Visdom十分轻量级，却支持非常丰富的功能，能胜任绝大多数的科学运算可视化任务，其可视化结果如下图所示。\n",
    "\n",
    "<img src=\"img/visdom.png\" style=\"zoom:30%\">\n",
    "\n",
    "![sample2](img/plots.gif)\n",
    "\n",
    "我们可以看到Visdom可以帮助我们展示数据的分布，模型的训练、模型结构、参数分布等，这些对于我们在debug中查找问题来源非常重要。更多的介绍我们可以参考下方的两个链接：\n",
    "- Visdom Github链接：[Github](https://github.com/fossasia/visdom)\n",
    "- Visdom 官方网站：[官网](https://ai.facebook.com/tools/visdom/)\n",
    "\n",
    "经过本节课的学习，你将收获：\n",
    "- 如何安装和使用Visdom\n",
    "- 了解Visdom基本知识\n",
    "- 使用Visdom进行绘图操作\n",
    "- 利用Visdom可视化训练过程\n",
    "\n",
    "*准备好了吗？按照以下步骤开始吧！*\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06d21cb1",
   "metadata": {},
   "source": [
    "# 2 安装和使用Visdom"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eaea6415",
   "metadata": {},
   "source": [
    "我们可以通过该pip命令来安装visdom `pip install visdom`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b56e337c",
   "metadata": {},
   "source": [
    "安装后，我们该如何启动Visdom呢？\n",
    "\n",
    "可以通过下方命令来启动，第一次启动时会下载一些相关文件。初次启动后最好在终端重新输入一次指令看看能否正常启动。\n",
    "\n",
    "```shell\n",
    "python -m visdom.server  # 或直接输入 visdom\n",
    "nohup python -m visdom.server &  # 还可以使用该命令将服务放到后台运行\n",
    "```\n",
    "\n",
    "如果能正常启动，终端将会显示如下信息：\n",
    "\n",
    "<img src=\"img/start.png\" style=\"zoom:30%\">\n",
    "\n",
    "复制`http://localhost:8097`到浏览器后，发现此时的界面并没有显示任何信息，如下图所示。\n",
    "\n",
    "<img src=\"img/UI.png\" style=\"zoom:30%\">\n",
    "\n",
    "没关系，在随后的内容中，我们将学习如何使用多个Panes（窗格）来填充它，并且这些窗格可以进行缩放、移动、删除等操作。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84baf708-2e42-4e10-9129-e08bba214f86",
   "metadata": {},
   "source": [
    "顶部的按钮含义如下图所示：\n",
    "\n",
    "<img src=\"img/UI2.png\" style=\"zoom:30%\">\n",
    "\n",
    "- 注意clear操作需双击。\n",
    "- 在状态为“offline”时，无法保存/删除/清空环境。只能进行过滤筛选操作。\n",
    "- 点击管理（外观为文件夹的icon）按钮后，弹出以下页面，可保存或删除当前环境的视图内容。\n",
    "\n",
    "<img src=\"img/UI3.png\" style=\"zoom:30%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eef1cbc5",
   "metadata": {},
   "source": [
    "# 3 Visdom基本知识"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc74191e",
   "metadata": {},
   "source": [
    "Visdom可以创建，共享多种数据形式的可视化，包括数值，图像，文本和视频，支持PyTorch，Numpy等接口。Visdom中主要有以下几个重要概念。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46005a50",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-31T14:29:08.537782Z",
     "start_time": "2023-03-31T14:29:08.519790Z"
    }
   },
   "source": [
    "## 3.1 env（环境）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ee70164",
   "metadata": {},
   "source": [
    "Environment 对可视化的区域进行分区，这样使得不同环境的可视化结果相互隔离，互不影响，在使用时如果不指定特定的env，默认将会使用main默认环境.\n",
    "\n",
    "我们可以通过编程或UI创建新的env。不同用户、不同程序一般使用不同env。\n",
    "\n",
    "这样做可以让我们通过分享url: http://localhost.com:8097/env/env_name 让其他人访问特定的env。\n",
    "\n",
    "我们的 envs 默认通过`$HOME/.visdom/ `加载。也可以将自定义的路径当作命令行参数传入。\n",
    "\n",
    "\n",
    "不要轻易移除目录下的`env_name.json`文件，这将导致相应的环境也会被删除。\n",
    "\n",
    "可以通过以下代码建立新的environment。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb6273e7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-31T14:31:36.830714Z",
     "start_time": "2023-03-31T14:31:36.810713Z"
    }
   },
   "outputs": [],
   "source": [
    "import visdom \n",
    "vis = visdom.Visdom(env='pytorchenv')\n",
    "# vis = visdom.Visdom(env=env_name)的作用是构建一个客户端\n",
    "# env_name 是指定的环境的名称（字符串），也可以指定host，port等其他参数"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e36476e-14e7-43be-91e7-6885fa92a92f",
   "metadata": {},
   "source": [
    "## 3.2 pane（窗格）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8636f3aa-d929-4132-ae88-a6adb18485cd",
   "metadata": {},
   "source": [
    "Pane可以理解为用于可视化图表、图片、文本、视频的容器。一个环境里可以使用不同的窗格来可视化或记录某一信息。\n",
    "\n",
    "我们可以对pane进行拖放，删除，调整大小和销毁等操作。\n",
    "\n",
    "一个程序既可以使用同一个env中的不同pane，也可以通过指定`win`来使用同一个env中的pane。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e367ee44-2da1-428d-a28c-688b1a1c862d",
   "metadata": {},
   "source": [
    "# 4 使用Visdom绘图"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f71e32c-a9c5-48cf-a2d4-de314485f089",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 4.1 基本绘制\n",
    "在接下来的示例中，我们将围绕常见的line、image、text等操作进行介绍。\n",
    "\n",
    "还需要注意的是visdom仅仅支持PyTorch的tensor和numpy的ndarray的数据结构，不支持Python中的int、float等类型，因此在传入前应该确保我们的数据格式是tensor或numpy。\n",
    "\n",
    "下面我们来运行一下最基础的示例吧~！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b133851f-6da2-4877-8a8b-1d9ba93de862",
   "metadata": {},
   "outputs": [],
   "source": [
    "import visdom\n",
    "import numpy as np\n",
    "vis = visdom.Visdom(env = 'pytorchenv')  # 指定使用的环境，若不指定将默认使用main\n",
    "vis.text('Hello, world!')  # 输出文本字符\n",
    "vis.image(np.zeros((3, 224, 224)))  # 输出大小为3*224*224（CxHxW）大小的黑色图片"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b36a548d-1cbc-4d7e-aff0-8a7f7a502148",
   "metadata": {},
   "source": [
    "我们发现原来的界面出现了一个文本和一幅图像\n",
    "\n",
    "<img src=\"img/demo1.png\" style=\"zoom:40%\">\n",
    "\n",
    "如下图所示，左上角的图标从左到右分别代表“关闭”、“下载”和“刷新”。点击右下角按钮可进行拖拽操作。\n",
    "\n",
    "<img src=\"img/demo1ex.png\" style=\"zoom:30%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fec7d956-b27a-4036-b3e5-deed47d6f33e",
   "metadata": {},
   "source": [
    "除此以外，我们通常会传入win和opt来进行设置。\n",
    "\n",
    "- `win`：用于指定pane的名字，若不指定，visdom会自动分配给我们一个新的pane。但是我们一般需要在原始图片上修改。因此建议每次操作都指定win。\n",
    "- `opts`：用来可视化配置，接收一个字典，常见的option包括title，xlabel，ylabel，width等，用来设置pane的显示格式。\n",
    "- `append`：在visdom中，每次操作都会覆盖前面的值，但在可视化损失函数时往往需要不断更新数值且不覆盖前面的数值，这时，只需要传入`update = 'append'`这个参数来避免覆盖之前的数值即可。\n",
    "\n",
    "再来尝试一下其他案例吧~！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a78b9dc-52bd-4ac7-815a-dcc1e1a86fca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import visdom as vis\n",
    "vis = vis.Visdom(env='pytorchenv')\n",
    "x = torch.arange(1,30,0.01)\n",
    "y = torch.sin(x)\n",
    "vis.line(X=x,Y=y,win='sinx',opts={'title':'y.sin(x)'})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2da5e16c-79d9-44f6-b071-36335ed34e24",
   "metadata": {},
   "source": [
    "结果如下所示：\n",
    "\n",
    "<img src=\"img/demo2.png\" style=\"zoom=30%\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6581ca29-ec70-4da5-bc55-7379f4627fb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import visdom as vis\n",
    "vis = vis.Visdom(env='pytorchenv')\n",
    "vis.image(torch.randn(64,64),win='rand1')  #可视化一张随机的黑白图片\n",
    "vis.image(torch.randn(3,64,64),win='rand2')  #可视化一张随机的彩色图片\n",
    "vis.images(torch.randn(36,3,64,64).numpy(), nrow=6, win='rand3', opts={'title':'demo'})  #可视化36张随机彩色图片，每一张6行"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d339460-275c-492c-8598-087ba7771e80",
   "metadata": {},
   "source": [
    "结果如下所示：\n",
    "\n",
    "<img src=\"img/demo3.png\" style=\"zoom=30%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d76e358-cb6d-49aa-83ab-cbe5e0d4e821",
   "metadata": {},
   "source": [
    "除了text、line、image、images以外，Visdom还支持以下基本可视化函数：\n",
    "\n",
    "- vis.image : 图片\n",
    "- vis.line: 曲线\n",
    "- vis.images : 图片列表\n",
    "- vis.text : 抽象HTML 输出文字\n",
    "- vis.properties : 属性网格\n",
    "- vis.audio : 音频\n",
    "- vis.video : 视频\n",
    "- vis.svg : SVG对象\n",
    "- vis.matplot : matplotlib图\n",
    "- vis.save : 序列化状态服务端\n",
    "\n",
    "上述函数可传入的参数：\n",
    "- opts.title : 图标题\n",
    "- win : 窗口名称\n",
    "- opts.width : 图宽\n",
    "- opts.height : 图高\n",
    "- opts.showlegend : 显示图例 (true or false)\n",
    "- opts.xtype : x轴的类型 ('linear' or 'log')\n",
    "- opts.xlabel : x轴的标签\n",
    "- opts.xtick : 显示x轴上的刻度 (boolean)\n",
    "- opts.xtickmin : 指定x轴上的第一个刻度 (number)\n",
    "- opts.xtickmax : 指定x轴上的最后一个刻度 (number)\n",
    "- opts.xtickvals : x轴上刻度的位置(table of numbers)\n",
    "- opts.xticklabels : 在x轴上标记标签 (table of strings)\n",
    "- opts.xtickstep : x轴上刻度之间的距离 (number)\n",
    "- opts.xtickfont :x轴标签的字体 (dict of font information)\n",
    "- 有关y轴的参数只需将上述x换成y即可\n",
    "- opts.marginleft : 左边框 (in pixels)\n",
    "- opts.marginright :右边框 (in pixels)\n",
    "- opts.margintop : 上边框 (in pixels)\n",
    "- opts.marginbottom : 下边框 (in pixels)\n",
    "- opts.lagent=[''] : 显示图标"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c6dfc69-1911-407c-946c-2ce85b6fa54f",
   "metadata": {},
   "source": [
    "## 4.2 其他图表绘制\n",
    "\n",
    "我们在第一步创建的`Visdom`的实例`vis`支持以下画图函数，这些函数接口由`Plotly`所提供。\n",
    "\n",
    "-   [`vis.scatter`](https://github.com/fossasia/visdom#visscatter) : 绘制2D 或 3D 散点图\n",
    "-   [`vis.line`](https://github.com/fossasia/visdom#visline) : 线形图\n",
    "-   [`vis.stem`](https://github.com/fossasia/visdom#visstem) : 茎状图\n",
    "-   [`vis.heatmap`](https://github.com/fossasia/visdom#visheatmap) : 热力图\n",
    "-   [`vis.bar`](https://github.com/fossasia/visdom#visbar) : 柱状图\n",
    "-   [`vis.histogram`](https://github.com/fossasia/visdom#vishistogram): 直方图\n",
    "-   [`vis.boxplot`](https://github.com/fossasia/visdom#visboxplot) : 箱线图\n",
    "-   [`vis.surf`](https://github.com/fossasia/visdom#vissurf) : 曲面图\n",
    "-   [`vis.contour`](https://github.com/fossasia/visdom#viscontour) : 等高线图\n",
    "-   [`vis.quiver`](https://github.com/fossasia/visdom#visquiver) : 折线图\n",
    "-   [`vis.mesh`](https://github.com/fossasia/visdom#vismesh) : 网格图\n",
    "-   [`vis.dual_axis_lines`](https://github.com/fossasia/visdom#visdual_axis_lines) : 双 y 轴线图\n",
    "\n",
    "下面将以柱状图的绘制为例。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec55fa01-e166-4f45-8065-a06710424932",
   "metadata": {},
   "outputs": [],
   "source": [
    "from visdom import Visdom\n",
    "import numpy as np\n",
    " \n",
    "vis = Visdom(env='pytorchenv')\n",
    "vis.bar(X=np.random.rand(4, 2),\n",
    "        win='test1',\n",
    "        opts=dict(\n",
    "        stacked=False,  # 是否堆叠柱形（若为False，效果如下图demo1所示；若为True，效果如下图demo2所示\n",
    "        legend=['A', 'B'],   # 图例标签名称\n",
    "        rownames=['top1', 'top5', 'top10', 'top20'],  # 列名称\n",
    "        title='demo1',  # 图表标题\n",
    "        ylabel='rank-k  Error Rate',  # y轴名称\n",
    "        xtickmin=0.4,  # x轴左端点起始位置\n",
    "        xtickstep=0.4  # 每个柱形间隔距离\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdb76dc3-5223-4993-8759-ccb1ed3c7edc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from visdom import Visdom\n",
    "import numpy as np\n",
    " \n",
    "vis = Visdom(env='pytorchenv')\n",
    "vis.bar(X=np.random.rand(4, 2),\n",
    "        win='test2',\n",
    "        opts=dict(\n",
    "        stacked=True,  # 是否堆叠柱形（若为False，效果如下图demoA所示；若为True，效果如下图demoB所示\n",
    "        legend=['A', 'B'],   # 图例标签名称\n",
    "        rownames=['top1', 'top5', 'top10', 'top20'],  # 列名称\n",
    "        title='demo2',  # 图表标题\n",
    "        ylabel='rank-k  Error Rate',  # y轴名称\n",
    "        xtickmin=0.4,  # x轴左端点起始位置\n",
    "        xtickstep=0.4  # 每个柱形间隔距离\n",
    "    ))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb361dc9-355e-4aa1-b5ff-30aa113d8ec2",
   "metadata": {},
   "source": [
    "输出结果如下图所示：\n",
    "<img src=\"img/demo4.png\" style=\"zoom=50%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22d85d78-a6c2-4a2a-b743-127723890530",
   "metadata": {},
   "source": [
    "## 4.3 可视化图片\n",
    "在处理图像数据时，可以使用Visdom对图像进行可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9eedb5da-8e5e-4615-9102-b81c62299e40",
   "metadata": {},
   "outputs": [],
   "source": [
    "import visdom  \n",
    "from PIL import Image  \n",
    "import torchvision.transforms.functional as F  \n",
    "  \n",
    "vis = visdom.Visdom(env='pytorchenv')  \n",
    "img = Image.open('img/Lenna.jpg')  \n",
    "img_tensor = F.to_tensor(img)  # 将图像转为tensor类型\n",
    "print(img_tensor.shape)  # 输出图片大小，可省略\n",
    "vis.image(img_tensor, win='photo')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df668052-57a4-4e39-aab1-86d2a0395e5d",
   "metadata": {},
   "source": [
    "输出结果如下图所示：\n",
    "<img src=\"img/photo.png\" style=\"zoom=50%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0a9cba8-9044-4c36-a06c-59cb28ec62d6",
   "metadata": {},
   "source": [
    "看完以上内容，你是否依然摸不着头脑，还是不知该如何运用这些函数与参数？没关系，在接下来的学习中我们会用一些实例帮助大家掌握。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85e177b8-2000-4f45-8758-e3e164181b8e",
   "metadata": {},
   "source": [
    "# 5 利用Visdom可视化训练过程"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e4aeae3-3725-4a59-add7-60b89f9e08d1",
   "metadata": {},
   "source": [
    "经过上述学习，相信大家已经对Visdom有了一个初步的了解，在接下来的这部分中，我们将通过具体地案例来帮助大家通过Visdom更好地查看损失函数变化。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8941c44-ab2b-4485-8f67-ffa8d107404a",
   "metadata": {},
   "source": [
    "## 5.1 绘制实时曲线"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "956b73a4-e3a6-43b7-a915-da6f1572fd6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 单条曲线绘制\n",
    "import visdom\n",
    "vis = visdom.Visdom(env=\"pytorchenv\")\n",
    "'''起点'''\n",
    "vis.line([0.],     #第一个点的Y坐标\n",
    "         [0.],     #第一个点的X坐标\n",
    "         win='train loss',  # 窗口名称\n",
    "         opts=dict(title = 'train_loss',xlabel='episodes',ylabel='loss') #图标题、x轴和Y轴标签\n",
    "         ) #设置起点\n",
    "'''模型数据'''\n",
    "vis.line([1.],[1.],       #下一点的Y坐标及X坐标\n",
    "         win='train loss',  # 窗口名称 与上个窗口同名表示显示在同一个表格里\n",
    "         update='append')  # 添加到上一个点的后面"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3285c51e-baec-475e-b679-0c3e134431aa",
   "metadata": {},
   "source": [
    "结果如下图所示：\n",
    "<img src=\"img/demo5.png\" style=\"zoom=50%\">\n",
    "\n",
    "- 点击右上角的标签按钮，出现详细的属性信息。\n",
    "- 鼠标悬浮在图片上方，可以进行更多操作，如放大、缩小、下载为png图片等。\n",
    "- 点击右下角的“edit”，还可对页面进行编辑操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b32f741-e31b-4451-92fd-560093084dda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 多条曲线绘制 实际上就是传入y值时为一个向量\n",
    "vis = visdom.Visdom(env=\"pytorchenv\")\n",
    "\n",
    "'''起点'''\n",
    "vis.line([[0.0,0.0]],    # Y的起始点\n",
    "          [0.],    # X的起始点\n",
    "         win=\"test loss\",    #窗口名称\n",
    "         opts=dict(title='test_loss')  # 图像标例\n",
    "        )\n",
    "'''模型数据'''\n",
    "vis.line([[1.1,1.5]],   # Y的下一个点\n",
    "        [1.],   # X的下一个点\n",
    "        win=\"test loss\",  # 窗口名称\n",
    "        update='append')   # 添加到上一个点后面"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c38e6bf2-717c-41e6-a220-76eb5e23ee5c",
   "metadata": {},
   "source": [
    "输出结果如下图所示：\n",
    "<img src=\"img/demo6.png\" style=\"zoom=30%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbbeee90-9e98-4151-882d-95b8d93bf309",
   "metadata": {},
   "source": [
    "## 5.2 初识可视化训练过程"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adc96d23-57c4-40ef-aa76-e831c3930c00",
   "metadata": {},
   "source": [
    "为方便学习，我们使用自带的MNIST数据进行可视化训练过程的展示。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a90aae67-b4ec-4a06-8e32-d9f30c1fbdbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "导入库文件\n",
    "'''\n",
    "import  torch\n",
    "import  torch.nn as nn\n",
    "import  torch.nn.functional as F\n",
    "import  torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "import visdom\n",
    "import numpy as np\n",
    "'''\n",
    "构建简单的模型:简单线性层+Relu函数的多层感知机\n",
    "'''\n",
    "class MLP(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MLP, self).__init__()\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Linear(784, 200),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(200, 200),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(200, 10),\n",
    "            nn.ReLU(inplace=True))\n",
    "    def forward(self, x):\n",
    "        x = self.model(x)\n",
    "        return x\n",
    "\n",
    "'''\n",
    "设置超参数\n",
    "'''\n",
    "batch_size = 128\n",
    "learning_rate = 0.01\n",
    "epochs = 10\n",
    "\n",
    "'''\n",
    "加载数据\n",
    "'''\n",
    "train_loader = torch.utils.data.DataLoader(datasets.MNIST(\n",
    "    'data', # \n",
    "    train=True,\n",
    "   download=True,\n",
    "    transform=transforms.Compose(\n",
    "        [transforms.ToTensor(),\n",
    "         transforms.Normalize((0.1307, ), (0.3081, ))])),\n",
    "                                           batch_size=batch_size,\n",
    "                                           shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(datasets.MNIST(\n",
    "    'data',\n",
    "    train=False,\n",
    "    transform=transforms.Compose(\n",
    "        [transforms.ToTensor(),\n",
    "         transforms.Normalize((0.1307, ), (0.3081, ))])),\n",
    "                                          batch_size=batch_size,\n",
    "                                          shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8261ee3-007e-4829-8293-b98ea67f7754",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 注意此处初始化visdom类\n",
    "vis = visdom.Visdom(env=\"pytorchenv\")\n",
    "\n",
    "# 绘制起点\n",
    "vis.line([0.], [0.], win=\"train loss\", opts=dict(title='train_loss'))\n",
    "device = torch.device('cuda:0')  # 指定GPU\n",
    "net = MLP().to(device)  # 初始化网络\n",
    "optimizer = optim.SGD(net.parameters(), lr=learning_rate)\n",
    "criteon = nn.CrossEntropyLoss().to(device)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data = data.view(-1, 28 * 28)\n",
    "        data, target = data.to(device), target.cuda()\n",
    "        logits = net(data)\n",
    "        loss = criteon(logits, target)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        # print(w1.grad.norm(), w2.grad.norm())\n",
    "        optimizer.step()\n",
    "        if batch_idx % 100 == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
    "                100. * batch_idx / len(train_loader), loss.item()))\n",
    "\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    for data, target in test_loader:\n",
    "        data = data.view(-1, 28 * 28)\n",
    "        data, target = data.to(device), target.cuda()\n",
    "        logits = net(data)\n",
    "        test_loss += criteon(logits, target).item()\n",
    "\n",
    "        pred = logits.argmax(dim=1)\n",
    "        correct += pred.eq(target).float().sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    # 绘制epoch以及对应的测试集损失loss\n",
    "    vis.line([test_loss], [epoch], win=\"train loss\", update='append') # win是必须的 \n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "    test_loss, correct, len(test_loader.dataset), correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86bd04b6-7ac3-4855-95d6-ae3a6e1ef189",
   "metadata": {},
   "source": [
    "得到输出如下图所示：\n",
    "<img src=\"img/demo7.png\" style=\"zoom=50%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "826ffa2d-912c-42d5-bc97-0209c615ae1d",
   "metadata": {},
   "source": [
    "以上内容只是Visdom工具的初步知识，更多有意思的操作等待大家在实际中探索~"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e90308d0-d50e-483d-b719-7ebd6ca2e36a",
   "metadata": {},
   "source": [
    "参考链接：\n",
    "1. [轻松学 Pytorch–Visdom 可视化](https://mp.weixin.qq.com/s/IjfS5E9HE-N8G2-C-ELTVg)\n",
    "2. [可视化工具Visdom的使用](http://t.csdn.cn/IxClS)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorchenv",
   "language": "python",
   "name": "pytorchenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
